# 恢复二叉搜索树: https://leetcode-cn.com/problems/recover-binary-search-tree/

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# 我自己的写法， 直接递归中序遍历找到两个节点对象，再更改他们的值
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
            1. 首先遍历找到需要交换的两个位置节点, 主要是总结出这样的规律： 看官方题解
                分情况：
                （1）当两个交换节点不相邻的时候， 那么会出现两个节点不符合 二叉搜索树的排序顺序， aj > aj + 1
                （2）当两个交换节点相邻时，那么只有 一个节点不符合 aj > aj + 1

                可以验证这样比较的正确性，且不会遗漏掉某些可能的边界交换。 两个值 一个在前，一个在后， 范围应该是 [0, n - 1], [1, n]。符合上面的 aj > aj + 1的两数范围。
                            1 2 3 4 5
                第一种      1 4 3 2 5
                第二种      1 2 4 3 5
            2. 再次交换进行值的替换
        """
        # x， y 代表交换的两个节点， x_next 表示 x后面的节点
        x, y, x_next = None, None, None
        # 进行 aj > aj + 1 判断时， pre 保存 aj， aj + 1 是当前node。 n 用来计数，看找到几个
        pre, n = None, 0

        # 1. 找到 交替的两个节点
        def inorder(node):
            nonlocal x; nonlocal y; nonlocal x_next
            nonlocal pre; nonlocal n
            if not node: return 

            inorder(node.left)
            # 查找 x， y
            if pre:
                if pre.val > node.val:
                    if n == 0:
                        x = pre
                        x_next = node
                        n += 1
                    elif n == 1:
                        y = node
                        n += 1
            pre = node
            inorder(node.right)
        
        inorder(root)
        if n == 1: y = x_next

        # 2. 进行交换
        x.val, y.val = y.val, x.val



# 优化的写法，代码简洁一些，少了几个变量，且找到x, y 后停止
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
            1. 首先遍历找到需要交换的两个位置节点, 主要是总结出这样的规律： 看官方题解
                分情况：
                （1）当两个交换节点不相邻的时候， 那么会出现两个节点不符合 二叉搜索树的排序顺序， aj > aj + 1
                （2）当两个交换节点相邻时，那么只有 一个节点不符合 aj > aj + 1

                可以验证这样比较的正确性，且不会遗漏掉某些可能的边界交换。 两个值 一个在前，一个在后， 范围应该是 [0, n - 1], [1, n]。符合上面的 aj > aj + 1的两数范围。
                            1 2 3 4 5
                第一种      1 4 3 2 5
                第二种      1 2 4 3 5
            2. 再次交换进行值的替换
        """
        # x， y 代表交换的两个节点， x_next 表示 x后面的节点
        x, y = None, None
        # 进行 aj > aj + 1 判断时， pre 保存 aj， aj + 1 是当前node。 n 用来计数，看找到几个
        pre = None

        # 1. 找到 交替的两个节点
        def inorder(node):
            nonlocal x; nonlocal y
            nonlocal pre
            if not node: return 

            inorder(node.left)
            # 查找 x， y
            if pre and pre.val > node.val:
                y = node
                if x == None:
                    x = pre
                else:
                    return 
            pre = node
            inorder(node.right)
        
        inorder(root)
        # 2. 进行交换
        x.val, y.val = y.val, x.val

